{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "afd0c153-6e87-4446-a38e-f63da8cc360f",
   "metadata": {},
   "source": [
    "### Run Llama fine-tuning code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fd6ef5d-2a76-4e3c-bdd5-415288b74c99",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Unsloth asks to be imported before trl/transformers so that its\n",
    "# optimizations are applied to them; keep these two imports first.\n",
    "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
    "from unsloth.chat_templates import get_chat_template\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import TrainingArguments, TextStreamer\n",
    "from trl import SFTTrainer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dd39f01-ce4e-4185-83d3-c6c2936a440f",
   "metadata": {},
   "source": [
    "#### Define the file path, model name, save name, and epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d74b505c-1421-4608-abe2-98ace7b755e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the JSONL file holding the training records.\n",
    "file_name = \"/home/users/seunghh/l2m3_revision/data/text_categorize.jsonl\"\n",
    "# Base checkpoint handed to FastLanguageModel.from_pretrained.\n",
    "model_name = \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\"\n",
    "# Destination passed to save_pretrained_merged after training.\n",
    "save_name = \"text_categorize_1\"\n",
    "# Value used for TrainingArguments.num_train_epochs.\n",
    "num_epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4fefb95-90d5-4944-8bc9-6e36b6faca97",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_files(file_name):\n",
    "    \"\"\"Load a JSON Lines file into a list of parsed objects.\n",
    "\n",
    "    Blank lines are skipped. A line that is not valid JSON is reported on\n",
    "    stdout together with its 1-based line number and then skipped, so one\n",
    "    malformed record does not abort the whole load.\n",
    "\n",
    "    Args:\n",
    "        file_name: Path of the JSONL file to read.\n",
    "\n",
    "    Returns:\n",
    "        list: One parsed JSON value per valid, non-empty line, in file order.\n",
    "\n",
    "    Raises:\n",
    "        OSError: If the file cannot be opened.\n",
    "        UnicodeDecodeError: If the file content is not valid UTF-8.\n",
    "    \"\"\"\n",
    "    data_list = []\n",
    "    # Decode explicitly as UTF-8 instead of relying on the platform's locale\n",
    "    # default, which can mis-read non-ASCII text in the records.\n",
    "    with open(file_name, 'r', encoding='utf-8') as g:\n",
    "        for line_number, line in enumerate(g, start=1):\n",
    "            line = line.strip()  # Remove any trailing or leading whitespace\n",
    "            if not line:  # Skip empty lines\n",
    "                continue\n",
    "            try:\n",
    "                data_list.append(json.loads(line))\n",
    "            except json.JSONDecodeError as e:\n",
    "                # Best effort: report the bad line and keep going\n",
    "                print(f\"Error parsing JSON on line {line_number}: {e}\")\n",
    "    return data_list\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "000fb7b3-f9e1-4e17-bf52-5c22cbdf530e",
   "metadata": {},
   "source": [
    "#### Load the data, define the model, and apply the chat template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fa8bf6b-5038-4134-9acb-518e03522022",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training records from the JSONL file\n",
    "data = read_files(file_name)\n",
    "\n",
    "# Sequence-length limit, shared by the model loader and the trainer\n",
    "max_seq_length = 2048\n",
    "\n",
    "# Load the pretrained model and its tokenizer with 4-bit loading enabled\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=model_name,\n",
    "    max_seq_length=max_seq_length,\n",
    "    load_in_4bit=True,\n",
    "    dtype=None,  # no explicit dtype; presumably chosen by the loader - confirm\n",
    ")\n",
    "\n",
    "# Wrap the model for Parameter-Efficient Fine-Tuning (PEFT) with LoRA adapters\n",
    "# (rank 16, rank-stabilised scaling) on the attention and MLP projections\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=16,\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0,\n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"up_proj\", \"down_proj\", \"o_proj\", \"gate_proj\"],\n",
    "    use_rslora=True,\n",
    "    use_gradient_checkpointing=\"unsloth\"\n",
    ")\n",
    "\n",
    "# Report the number of trainable parameters. The method prints the summary\n",
    "# itself and returns None, so it must not be wrapped in print() (that would\n",
    "# emit an extra \"None\" line).\n",
    "model.print_trainable_parameters()\n",
    "\n",
    "# Attach a chat template to the tokenizer.\n",
    "# NOTE(review): no chat_template is named, so the library default is used -\n",
    "# confirm that default is the format expected for this Llama 3.2 model.\n",
    "tokenizer = get_chat_template(\n",
    "    tokenizer,\n",
    ")\n",
    "\n",
    "def apply_template(examples):\n",
    "    \"\"\"Render one record's chat messages into a single training string.\n",
    "\n",
    "    The record's \"messages\" entry is formatted with the tokenizer's chat\n",
    "    template (as text, without a trailing generation prompt) and stored under\n",
    "    the \"text\" key. The record is modified in place and also returned.\n",
    "    \"\"\"\n",
    "    examples['text'] = tokenizer.apply_chat_template(\n",
    "        examples[\"messages\"],\n",
    "        tokenize=False,\n",
    "        add_generation_prompt=False,\n",
    "    )\n",
    "    return examples\n",
    "\n",
    "# Render every training record through the chat template\n",
    "dataset = [apply_template(record) for record in data]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "990b07e9-789e-4350-b47e-15b1d72fbc92",
   "metadata": {},
   "source": [
    "#### Set up the Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de345e4b-605e-4298-bdfe-ae9893faeaab",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Mixed precision: bf16 when is_bfloat16_supported() reports True, fp16 otherwise\n",
    "use_bf16 = is_bfloat16_supported()\n",
    "\n",
    "# Optimisation, logging and output settings for the run\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"output\",\n",
    "    seed=0,\n",
    "    num_train_epochs=num_epochs,\n",
    "    # 4 sequences per device x 4 accumulation steps per optimizer update\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=4,\n",
    "    optim=\"adamw_8bit\",\n",
    "    learning_rate=3e-4,\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    warmup_steps=10,\n",
    "    weight_decay=0.01,\n",
    "    bf16=use_bf16,\n",
    "    fp16=not use_bf16,\n",
    "    logging_steps=1,\n",
    "    logging_dir='./logs',\n",
    "    report_to=['tensorboard'],\n",
    ")\n",
    "\n",
    "# Supervised fine-tuning trainer over the templated \"text\" field, with packing enabled\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    train_dataset=dataset,\n",
    "    dataset_text_field=\"text\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    dataset_num_proc=2,\n",
    "    packing=True,\n",
    "    args=training_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ee6b26c-405e-4877-b674-8d7ce3ca9c55",
   "metadata": {},
   "source": [
    "#### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c025ce77-4c9f-4e28-8459-56e2989923e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run fine-tuning with the trainer configured above. The call is the cell's\n",
    "# last expression, so the notebook displays whatever train() returns.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "544d4c0c-fb3d-4906-bbe0-abcddd131d44",
   "metadata": {},
   "source": [
    "#### Save the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8ec55ec-cc02-4f0b-a1c0-d2f1029eea57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the merged model together with its tokenizer under `save_name`,\n",
    "# using the \"merged_16bit\" save method\n",
    "model.save_pretrained_merged(\n",
    "    save_name,\n",
    "    tokenizer,\n",
    "    save_method=\"merged_16bit\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "383dd725-5db9-4e57-8d44-58fcc6d05851",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
